library(reticulate)
use_python("/home/nealpsmith/.conda/envs/sc_analysis/bin/python")

For figure 5, we wanted to focus in on the mononuclear phagocytes (MNP). We subsetted our main data object to just the MNP clusters. From there we wanted to find the heterogeneity among the MNP and find out which MNP clusters were enriched in AA or AC.

First, we isolated just the MNP clusters from our original data object. Next, we wanted to determine the optimal number of clusters for our MNP. Briefly, similar to Reyes et al., we quantified cluster stability at many Leiden resolutions to determine which one should be used for downstream analyses. First, data was clustered at many different resolutions ranging from 0.3 to 1.9. Next, we iteratively subsampled 90% of the data (20 iterations) and at each iteration, we made a new clustering solution with the subsampled data and used the Adjusted Rand Index (ARI) to compare it to the clustering solution from the full data.

import pegasus as pg
import numpy as np
import pandas as pd
import concurrent.futures
from sklearn.metrics.cluster import adjusted_rand_score
import random
import time
import leidenalg
import concurrent.futures
import os
from pegasus.tools import construct_graph
from scipy.sparse import csr_matrix

# Use Rand index to determine leiden resolution to use
def rand_index_plot(
        W,  # adata.uns['W_' + rep] or adata.uns['neighbors']
        resamp_perc=0.9,
        resolutions=(0.3, 0.5, 0.7, 0.9, 1.1, 1.3, 1.5, 1.7, 1.9),
        max_workers=25,
        n_samples=25,
        random_state=0
    ):
    """Measure Leiden clustering stability over a range of resolutions.

    For each resolution the full matrix is clustered once with ``leiden``;
    then ``n_samples`` calls to ``_collect_samples`` are submitted to a
    process pool, each returning one adjusted Rand index (ARI) for a
    subsample of the cells.

    Parameters
    ----------
    W : scipy.sparse.csr_matrix
        Matrix handed to ``leiden``; its first dimension is taken as the
        number of cells.
    resamp_perc : float
        Fraction of cells drawn for every subsample.
    resolutions : iterable of float
        Leiden resolutions to evaluate.
    max_workers : int
        Size of the process pool created for each resolution.
    n_samples : int
        Number of subsampling jobs per resolution.
    random_state : int
        Seed forwarded to every ``leiden`` call.

    Returns
    -------
    dict
        ``str(resolution)`` -> list with ``n_samples`` ARI values.
    """
    assert isinstance(W, csr_matrix)
    total_cells = W.shape[0]
    subsample_size = round(total_cells * resamp_perc)
    ari_by_resolution = {}

    for res in resolutions:
        # Reference labels on the complete data set.
        reference_labels = leiden(W, res, random_state)
        job_args = (W, res, total_cells, subsample_size, reference_labels, random_state)

        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as pool:
            pending = []
            for _ in range(n_samples):
                pending.append(pool.submit(_collect_samples, *job_args))
            # Collect in submission order; .result() blocks until the job is done.
            ari_values = []
            for job in pending:
                ari_values.append(job.result())

        ari_by_resolution[str(res)] = ari_values
        print("Finished {res}".format(res=res))
    return ari_by_resolution
def leiden(W, resolution, random_state=0):
    """Cluster the graph built from ``W`` with the Leiden algorithm.

    Parameters
    ----------
    W : matrix accepted by ``pegasus.tools.construct_graph``
        Presumably the cell-by-cell affinity matrix (callers in this file
        pass a ``csr_matrix``) - TODO confirm.
    resolution : float
        Passed to leidenalg as ``resolution_parameter``.
    random_state : int
        Passed to leidenalg as ``seed``.

    Returns
    -------
    pandas.Series
        One string label per vertex, numbered from "1" (the partition's
        membership value plus one), with a default integer index.
    """
    G = construct_graph(W)
    partition = leidenalg.find_partition(
        G,
        leidenalg.RBConfigurationVertexPartition,
        seed=random_state,
        weights="weight",
        resolution_parameter=resolution,
        # NOTE(review): a negative value presumably means "iterate until the
        # partition stops improving" - confirm against the leidenalg docs.
        n_iterations=-1,
    )

    # Shift the membership ids by one and store them as strings.
    labels = np.array([str(x + 1) for x in partition.membership])
    return pd.Series(labels)

def _collect_samples(W, resolution, n_cells, resamp_size, true_class, random_state=0):
    """Cluster one random subsample and score it against the full-data labels.

    Parameters
    ----------
    W : sparse matrix
        Square matrix over all cells; rows and columns are subset with the
        same sampled indices.
    resolution : float
        Leiden resolution forwarded to ``leiden``.
    n_cells : int
        Total number of cells (range the sample is drawn from).
    resamp_size : int
        Number of cells to draw, without replacement.
    true_class : array-like
        Full-data cluster labels, one per cell, in the row order of ``W``.
    random_state : int
        Seed forwarded to ``leiden`` (the subsample itself is drawn from the
        global ``random`` state).

    Returns
    -------
    float
        Adjusted Rand index between the reference labels of the sampled
        cells and the labels obtained by re-clustering the subsample.
    """
    samp_indx = random.sample(range(n_cells), resamp_size)
    samp_data = W[samp_indx][:, samp_indx]
    # samp_indx holds row *positions*. Indexing a pandas Series with a list of
    # ints is a label lookup, which is only equivalent for a default
    # RangeIndex; going through a plain array makes the selection positional
    # for any input.
    sampled_true = np.asarray(true_class)[samp_indx]
    new_class = leiden(samp_data, resolution, random_state)
    return adjusted_rand_score(sampled_true, new_class)
import matplotlib.pyplot as plt
import seaborn as sns

# Load the harmonized MNP object; the chunks below read it as `mnp_harmonized`.
mnp_harmonized = pg.read_input("/home/nealpsmith/projects/medoff/data/anndata_for_publication/mnp_harmonized.h5ad")

# The ARI computation is disabled here; the plot below reads a previously
# saved CSV instead.
# NOTE(review): `file_path()` is not defined in the visible code - confirm
# where it comes from before re-enabling this section.
# rand_indx_dict = rand_index_plot(W = mnp_harmonized.uns["W_pca_harmony"],
#                                       resolutions  = [0.3, 0.5, 0.7, 0.9, 1.1, 1.3, 1.5, 1.7, 1.9],
#                                       n_samples = 2)
#
# plot_df = pd.DataFrame(rand_indx_dict).T
# plot_df = plot_df.reset_index()
# plot_df = plot_df.melt(id_vars="index")
# plot_df.to_csv(os.path.join(file_path(), "data", "ari_plots", "myeloid_harmonized_ARI.csv"))
## 2022-11-28 14:11:35,805 - pegasus - INFO - Time spent on 'read_input' = 4.20s.
# Read the cached ARI table: "index" is the x-axis category (labelled
# "leiden resolution" below) and "value" the ARI.
plot_df = pd.read_csv("/home/nealpsmith/projects/medoff/data/ari_plots/myeloid_harmonized_ARI.csv")
fig, ax = plt.subplots(1)
_ = sns.boxplot(x="index", y="value", data=plot_df, ax = ax)
# Recolor every box grey.
# NOTE(review): this only has an effect if the boxes are exposed through
# `ax.artists`; confirm that holds for the installed seaborn/matplotlib
# versions, otherwise the loop silently does nothing.
for box in ax.artists:
    box.set_facecolor("grey")
# ax.artists[6].set_facecolor("red") # The one we chose!
# Cosmetic cleanup: drop the top/right frame and enlarge tick and axis labels.
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=15)
_ = ax.set_ylabel("Adjusted Rand index", size = 20)
_ = ax.set_xlabel("leiden resolution", size = 20)
# Dashed reference line at ARI = 0.9.
_ = plt.axhline(y = 0.9, color = "black", linestyle = "--")
fig.tight_layout()
# Bare expression so the notebook renders the figure.
fig

Based on this rand index approach, we can see that a leiden resolution of 1.5 is the highest resolution where the median ARI of all iterations was > 0.9. Given this, we started our clustering at this resolution.

import scanpy as sc
import matplotlib.colors as clr
# Two-color gradient (light grey -> purple) reused below for the score UMAP.
colormap = clr.LinearSegmentedColormap.from_list('gene_cmap', ["#d3d3d3" ,'#482cc7'], N=200)

# Clustering call is disabled; "leiden_labels" is read from the loaded object.
# pg.leiden(mnp_harmonized, resolution = 1.5, rep = "pca_harmony", random_state = 2)

# UMAP colored by the Leiden labels, with the labels drawn on the embedding.
sc.pl.umap(mnp_harmonized, color = ["leiden_labels"], legend_loc = "on data")

After visually assessing many genes and through expert annotation, we could see that there was a small cluster of CLEC9A+ cDC1 cells among cluster 3. Given we know that cDC1s have a distinct biological function, we manually subsetted these cells to represent their own cluster. We can defend this choice by scoring cells using a cDC1 gene set that was published by Villani et al. Cell 2017.


### Want to look at DC1 gene set (we think we see a small cluster of them) ###
def score_cells(data, gene_set):
    """Score every cell for a gene set as the mean of clipped per-gene z-scores.

    Parameters
    ----------
    data : AnnData-like
        Must expose ``var_names`` and support ``data[:, genes].X``; ``X`` may
        be a sparse matrix (anything with ``.toarray()``) or a dense array.
    gene_set : iterable of str
        Gene names; names missing from ``data.var_names`` are ignored.

    Returns
    -------
    numpy.ndarray
        One score per cell. Each gene is z-scored across cells, clipped to
        [-5, 5], and the z-scores are averaged per cell. Genes with zero
        variance carry no information and are left out of the average; if no
        informative gene remains, every cell gets a score of 0.
    """
    # Get rid of genes that aren't in data
    gene_set = [gene for gene in gene_set if gene in data.var_names]
    print(gene_set)

    # Limit the data to just those genes; densify only when X is sparse
    expr = data[:, gene_set].X
    if hasattr(expr, "toarray"):
        expr = expr.toarray()
    expr = np.asarray(expr)

    std = expr.std(axis=0)
    # A constant gene would give 0/0 = NaN for every cell, and that NaN would
    # propagate into every cell's mean, so such genes are dropped up front.
    informative = std > 0
    if not informative.any():
        return np.zeros(expr.shape[0])

    expr = expr[:, informative]
    z_scores = (expr - expr.mean(axis=0)) / std[informative]
    # Cap outliers so that a single extreme gene cannot dominate a cell's score.
    z_scores = np.clip(z_scores, -5, 5)

    return z_scores.mean(axis=1)

# cDC1 gene set used for scoring (per the accompanying text, the set published
# by Villani et al., Cell 2017).
dc1_genes = ["CLEC9A", "C1ORF54", "HLA-DPA1", "CADM1", "CAMK2D", "CPVL", "HLA-DPB2", "WDFY4", "CPNE3", "IDO1",
            "HLA-DPB1", "LOC645638", "HLA-DOB", "HLA-DQB1", "HLA-DQB", "CLNK", "CSRP1", "SNX3", "ZNF366",
             "KIAA1598", "NDRG2", "ENPP1", "RGS10", "AX747832", "CYB5R3", "ID2", "XCR1", "FAM190A", "ASAP1",
             "SLAMF8", "CD59", "DHRS3", "GCET2", "FNBP1", "TMEM14A", "NET1", "BTLA", "BCL6", "FLT3", "ADAM28", "SLAMF7",
             "BATF3", "LGALS2", "VAC14", "PPA1", "APOL3", "C1ORF21", "CCND1", "ANPEP", "ELOVL5", "NCALD", "ACTN1",
             "PIK3CB", "HAVCR2", "GYPC", "TLR10", "ASB2", "KIF16B", "LRRC18", "DST", "DENND1B", "DNASE1L3", "SLC24A4",
             "VAV3", "THBD", "NAV1", "GSTM4", "TRERF1", "B3GNT7", "LACC1", "LMNA", "PTK2", "IDO2", "MTERFD3", "CD93",
             "DPP4", "SLC9A9", "FCRL6", "PDLIM7", "CYP2E1", "PDE4DIP", "LIMA1", "CTTNBP2NL", "PPM1M", "OSBPL3", "PLCD1",
             "CD38", "EHD4", "ACSS2", "LOC541471", "FUCA1", "SNX22", "APOL1", "DUSP10", "FAM160A2", "INF2", "DUSP2",
             "PALM2", "RAB11FIP4", "DSE", "FAM135A", "KCNK6", "PPM1H", "PAFAH1B3", "PDLIM1", "TGM2", "SCARF1", "CD40",
             "STX3", "WHAMMP3", "PRELID2", "PQLC2"]
# Keep only genes present in the data (score_cells applies the same filter
# again internally, so this step is redundant but harmless).
dc1_genes = [gene for gene in dc1_genes if gene in mnp_harmonized.var_names]
# Store the per-cell score in obs so it can be plotted on the UMAP.
mnp_harmonized.obs["dc1_score"] = score_cells(mnp_harmonized, dc1_genes)

sc.pl.umap(mnp_harmonized, color="dc1_score", cmap = colormap)

Additionally, after expert annotation, we did not think cluster 10 represented any unique set of cells and seemed to share all major markers with cluster 1. Therefore, we refined our clustering by combining clusters 1 and 10 and manually segregating the cDC1 cells. Below is the final clustering solution. Cluster numbers were also adjusted to have them ordered by their phenotypic similarities (the order they fall in the marker heatmap below).

# UMAP colored by the refined "new_clusters" labels, drawn on the embedding.
sc.pl.umap(mnp_harmonized, color = "new_clusters", legend_loc = "on data")

Marker genes

First we can look at marker genes by AUROC. The motivation here is to determine for each cluster which specific genes are good classifiers for cluster membership. These stats were calculated using the Pegasus de_analysis function.

# pg.de_analysis(mnp_harmonized, cluster = "new_clusters", auc = True,
#                n_jobs = len(set(mnp_harmonized.obs["new_clusters"])))

top_auc = {}
top_genes = {}
de_table = mnp_harmonized.varm["de_res"]
cluster_ids = sorted(set(mnp_harmonized.obs["new_clusters"]), key=int)
for clust in cluster_ids:
    # One-column table of this cluster's AUROC, indexed by gene name.
    auc_frame = pd.DataFrame(
        {"auc": de_table["auroc:{clust}".format(clust=clust)]},
        index=mnp_harmonized.var.index,
    )
    # Look at the 50 best-ranked genes and keep those with AUROC >= 0.75.
    leading = auc_frame.sort_values(by=["auc"], ascending=False).head(50)
    top_genes[clust] = leading.loc[leading["auc"] >= 0.75].index.values

# One column per cluster; shorter columns are padded with "" instead of NaN.
top_gene_df = pd.DataFrame({key: pd.Series(val) for key, val in top_genes.items()})
top_gene_df = top_gene_df.rename(
    columns={key: "cluster_{clust}".format(clust=key) for key in top_genes}
)
top_gene_df = top_gene_df.replace(np.nan, "")

library(knitr)
# Render the marker table built in the Python chunk (`top_gene_df`).
kable(reticulate::py$top_gene_df, caption = "genes with AUC > 0.75")
genes with AUC > 0.75
cluster_1 cluster_2 cluster_3 cluster_4 cluster_5 cluster_6 cluster_7 cluster_8 cluster_9 cluster_10 cluster_11 cluster_12 cluster_13 cluster_14
GBP1 CTSB NAMPT APOC1 FTL H2AFZ VSIR A2M GZMB CD83 CLEC9A CD1C PPP1R14A
STAT1 CSTB SRGN FTL FTH1 STMN1 FGL2 CD74 JCHAIN BIRC3 WDFY4 CD1E SEPT6
C15orf48 CTSD NEAT1 FABP4 APOC1 HMGB1 AOAH HLA-DQA1 TCF4 LAMP3 SNX3 CIITA LTB
WARS LGALS3 PLAUR MARCO FABP4 TUBA1B CPVL HLA-DPA1 UGCG NFKB1 IRF8 TLR10 C12orf75
LAP3 CD63 SLC2A3 GCHFR CTSD PTMA MS4A6A C1QC ITM2C DAPP1 NAAA CKLF CDH1
GBP4 FABP5 SAT1 C1QA CSTB PCLAF SORL1 HLA-DQB1 IGKC REL ID2 PPA1 NKG7
GBP5 CTSL IL1RN C1QB MARCO HMGB2 TMEM176B HLA-DRB1 PPP1R14B CCR7 CPNE3 HLA-DQA1 AXL
GBP2 BRI3 FTH1 FTH1 LGALS3 TYMS MNDA C1QA BCL11A BASP1 SHTN1 PPP1R14A MYL12A
SOD2 GPNMB VEGFA APOE APOE TUBB CLEC10A C1QB IRF7 IL7R CST3 ACTG1 CD2
CXCL10 PSAP CEBPB MS4A7 DUT F13A1 ITM2B CCDC50 CST7 CADM1 CD74 ARL4C
PARP14 SDCBP MCL1 MSR1 CKS1B GGT5 HLA-DMB TSPAN13 CCL22 CD226 HLA-DPA1 TCF4
EPSTI1 CEBPB BTG1 TFRC HMGN2 AIF1 HLA-DRA MZB1 CFLAR TAP1 SLC38A1 SULF2
TYMP ATP6V1F TIMP1 LGALS3 PCNA VASH1 HLA-DPB1 IRF8 GPR157 HLA-DPA1 DENND1B LGMN
NAMPT FTH1 METRNL CTSD DEK JAML AXL SMPD3 MARCKS RAB7B LIMD2 HINT1
TAP1 EMP3 DUSP1 NUPR1 TK1 CD93 MS4A6A IL3RA NR4A3 ENPP1 KCNK6 BID
BCL2A1 ANXA2 OLR1 SERPING1 GAPDH NAIP FGL2 C12orf75 IL4I1 CPVL HLA-DPB1 ATF5
PSME2 ANXA5 ATP13A3 TSPO MKI67 GIMAP7 CLIC3 BTG1 BATF3 NDRG2 PLAC8
APOL6 GSTO1 SGK1 SCD SIGLEC10 PLAC8 GPR183 HLA-DPB1 HLA-DQB1 RPS23
SGK1 CD44 MGST3 CEBPA MALAT1 ETV3 XCR1 LRRK1 TUBA1A
SLC11A1 SOD2 CSTB COTL1 SEC61B CSF2RA IDO1 GDI2 LRRFIP1
ASAH1 FOSL2 AC026369.3 APP EIF1 HLA-DQA1 RUNX3 CSF2RB
PDXK HIF1A CES1 SELL CMTM6 CD74 SPIB
PLA2G7 H3F3B LTA4H SERPINF1 CDKN1A BASP1 RAB11FIP1
FTL LITAF GLRX LILRA4 PNRC1 CCND1 MRPS6
AQP9 GK CYP27A1 MAP1A EEF1A1 CLNK CRYBG1
LAMP1 C1QC DERL3 HLA-DQB1 PLP2
GLUL GPNMB ALOX5AP C1orf54 SYNGR2
S100A11 CD52 SCT RAB11FIP1 S100A10
CD9 FABP5 SPCS1 ASAP1 NFAT5
CD44 PPARG HERPUD1 SLAMF7 ACTG1
SH3BGRL3 TXNIP SPIB ICAM3 CCND3
MGAT1 LGALS3BP CYB561A3 FNBP1 SUSD1
VIM FAM89A LINC00996 ARF6 ALOX5AP
PLIN2 S100A11 LTB CST7 SEC61B
FBP1 CFD PLD4 RAB32 UGCG
DAB2 ALDH2 CLEC4C HLA-DRB1 INPP4A
CD68 FBP1 STMN1 ACTG1 RASSF5
CD163 S100A6 CD2AP DENND1B
C15orf48 VSIG4 SLC15A4 RGCC
S100A10 ACP5 PMEPA1 BCL6
NCF2 FCGR3A CCDC186 HLA-DOB
CCL2 CD81 FCHSD2 HLA-DRB5
PCOLCE2 ZFAT RGS10
ALDH1A1 MPEG1 DNASE1L3
TYROBP PLP2 DST
TREM1 OFD1 ZNF366
PLIN2 NUCB2 LGALS2
CTSL SELENOS CCSER1
PLA2G16 RPS12 HLA-DRA
IFI6 SEL1L3 CAMK2D

We can see from the above AUROC genes that we don't have a strong enough signal from some clusters to get a good sense of their phenotype solely on that. So we can also find markers using an OVA (one-vs-all) pseudobulk approach. To do this, we first created a pseudobulk matrix by summing the UMI counts across cells for each unique cluster/sample combination, creating a matrix of n genes x (n samples * n clusters). Using this matrix with DESeq2, for each cluster, we used an input model `gene ~ in_clust`, where `in_clust` is a factor with two levels indicating if the sample was in or not in the cluster being tested. Genes with an FDR < 5% were considered marker genes.

# import neals_python_functions as nealsucks
# # Read in the raw count data
# raw_data = pg.read_input("/home/nealpsmith/projects/medoff/data/all_data.h5sc")
# raw_data = raw_data[mnp_harmonized.obs_names]
# raw_data = raw_data[:, mnp_harmonized.var_names]
# raw_data.obs = mnp_harmonized.obs[["leiden_labels", "Channel"]]
#
# # Create the matrix
# raw_sum_dict = {}
# cell_num_dict = {}
# for samp in set(raw_data.obs["Channel"]):
#     for clust in set(raw_data.obs["leiden_labels"]):
#         dat = raw_data[(raw_data.obs["Channel"] == samp) & (raw_data.obs["leiden_labels"] == clust)]
#         if len(dat) == 0:
#             continue
#         cell_num_dict["samp_{samp}_{clust}".format(samp=samp, clust=clust)] = len(dat)
#         count_sum = np.array(dat.X.sum(axis=0)).flatten()
#         raw_sum_dict["samp_{samp}_{clust}".format(samp=samp, clust=clust)] = count_sum
#
# count_mtx = pd.DataFrame(raw_sum_dict, index=raw_data.var.index.values)
#
# meta_df = pd.DataFrame(cell_num_dict, index=["n_cells"]).T
# meta_df["cluster"] = [name.split("_")[-1] for name in meta_df.index.values]
# meta_df["sample"] = [name.split("_")[-2] for name in meta_df.index.values]
# meta_df["phenotype"] = [name.split("_")[-3] for name in meta_df.index.values]
# meta_df["id"] = ["_".join(name.split("_")[0:2]) for name in meta_df.index.values]
#
# clust_df = pd.DataFrame(index=count_mtx.index)
# # Lets run pseudobulk on clusters
# for clust in set(mnp_harmonized.obs["leiden_labels"]):
#     print(clust)
#     meta_temp = meta_df.copy()
#     meta_temp["isclust"] = ["yes" if cluster == clust else "no" for cluster in meta_temp["cluster"]]
#
#     assert all(meta_temp.index.values == count_mtx.columns)
#     # Run DESeq2
#     deseq = nealsucks.analysis.deseq2.py_DESeq2(count_matrix=count_mtx, design_matrix=meta_temp,
#                                                 design_formula="~ isclust")
#     deseq.run_deseq()
#     res = deseq.get_deseq_result()
#     clust_df = clust_df.join(res[["pvalue"]].rename(
#         columns={"pvalue": "pseudobulk_p_val:{clust}".format(clust=clust)}))

# Pull the DE table into a DataFrame (one row per gene) so that q-value
# columns can be added next to the stored pseudobulk p-values.
de_res = mnp_harmonized.varm["de_res"]
# de_res = pd.DataFrame(de_res, index=res.index)
# de_res = de_res.join(clust_df)
de_res = pd.DataFrame(de_res, index = mnp_harmonized.var_names)
names = [name for name in de_res.columns if name.startswith("pseudobulk_p_val")]

# A missing p-value means the gene could not be tested. It must not become
# p = 0 (the most significant value possible), so fill p-value columns with
# the conservative 1 first; every other statistic keeps the 0 fill.
if names:
    de_res[names] = de_res[names].fillna(1)
de_res = de_res.fillna(0)

import statsmodels.stats.multitest as stats
for name in names :
    # Column names look like "pseudobulk_p_val:<cluster>".
    clust = name.split(":", 1)[1]
    # fdrcorrection returns (rejected, corrected p-values); keep the latter.
    de_res["pseudobulk_q_val:{clust}".format(clust = clust)] = stats.fdrcorrection(de_res[name])[1]

# Store back in the record-array layout that the rest of the notebook
# indexes by column name.
de_res = de_res.to_records(index=False)
mnp_harmonized.varm["de_res"] = de_res

top_genes = {}
de_table = mnp_harmonized.varm["de_res"]
for clust in sorted(set(mnp_harmonized.obs["leiden_labels"]), key=int):
    # Per-gene statistics for this cluster, indexed by gene name.
    stats_df = pd.DataFrame(
        {
            "auc": de_table["auroc:{clust}".format(clust=clust)],
            "pseudo_q": de_table["pseudobulk_q_val:{clust}".format(clust = clust)],
            "pseudo_p": de_table["pseudobulk_p_val:{clust}".format(clust = clust)],
            "pseudo_log_fc": de_table["pseudobulk_log_fold_change:{clust}".format(clust = clust)],
            "percent": de_table["percentage:{clust}".format(clust = clust)],
        },
        index=mnp_harmonized.var.index,
    )
    # Only genes with "percent" above 20, best AUROC first.
    expressed = stats_df.loc[stats_df["percent"] > 20].sort_values(by=["auc"], ascending=False)

    # Step 1: among the 50 best-ranked genes, keep those with AUROC >= 0.75.
    leading = expressed.head(50)
    genes = leading.loc[leading["auc"] >= 0.75].index.values

    # Step 2: fill the remaining slots (up to 50) with pseudobulk markers
    # (q < 0.05), largest log fold change first, without repeating a gene.
    n_from_pseudo = 50 - len(genes)
    if n_from_pseudo > 0:
        remaining = expressed.drop(genes)
        significant = remaining.loc[remaining["pseudo_q"] < 0.05]
        ranked = significant.sort_values(by="pseudo_log_fc", ascending=False)
        fill_in = [name for name in ranked.index.values[:n_from_pseudo] if name not in genes]
        genes = np.concatenate((genes, fill_in))

    print("Cluster {clust}: {length}".format(clust = clust, length = len(genes)))
    top_genes[clust] = genes

# One column per cluster; shorter columns are padded with "" instead of NaN.
top_gene_df = pd.DataFrame({key: pd.Series(val) for key, val in top_genes.items()})
top_gene_df = top_gene_df.rename(
    columns={key: "cluster_{clust}".format(clust=key) for key in top_genes}
)
top_gene_df = top_gene_df.replace(np.nan, "")

# Render the combined AUROC + pseudobulk marker table built in Python.
kable(reticulate::py$top_gene_df, caption = "genes with AUC> 0.75 or pseudo q < 0.05")
genes with AUC> 0.75 or pseudo q < 0.05
cluster_1 cluster_2 cluster_3 cluster_4 cluster_5 cluster_6 cluster_7 cluster_8 cluster_9 cluster_10 cluster_11 cluster_12 cluster_13 cluster_14
GBP1 CTSB NAMPT APOC1 FTL SCGB1A1 H2AFZ VSIR A2M GZMB CD83 CLEC9A CD1C PPP1R14A
STAT1 CSTB SRGN FTL FTH1 SCGB3A1 STMN1 FGL2 CD74 JCHAIN BIRC3 WDFY4 CD1E SEPT6
C15orf48 CTSD NEAT1 FABP4 APOC1 PKIB HMGB1 AOAH HLA-DQA1 TCF4 LAMP3 SNX3 CIITA LTB
WARS LGALS3 PLAUR MARCO FABP4 FCER1A TUBA1B CPVL HLA-DPA1 UGCG NFKB1 IRF8 TLR10 C12orf75
LAP3 CD63 SLC2A3 GCHFR CTSD HLA-DQB2 PTMA MS4A6A C1QC ITM2C DAPP1 NAAA CKLF CDH1
GBP4 FABP5 SAT1 C1QA CSTB ABI3 PCLAF SORL1 HLA-DQB1 IGKC REL ID2 PPA1 NKG7
GBP5 CTSL IL1RN C1QB MARCO TMEM14C HMGB2 TMEM176B HLA-DRB1 PPP1R14B CCR7 CPNE3 HLA-DQA1 AXL
GBP2 BRI3 FTH1 FTH1 LGALS3 YWHAH TYMS MNDA C1QA BCL11A BASP1 SHTN1 PPP1R14A MYL12A
SOD2 GPNMB VEGFA APOE APOE IL18 TUBB CLEC10A C1QB IRF7 IL7R CST3 ACTG1 CD2
CXCL10 PSAP CEBPB MS4A7 NUPR1 RAB32 DUT F13A1 ITM2B CCDC50 CST7 CADM1 CD74 ARL4C
PARP14 SDCBP MCL1 MSR1 AC026369.3 HLA-DPB1 CKS1B GGT5 HLA-DMB TSPAN13 CCL22 CD226 HLA-DPA1 TCF4
EPSTI1 CEBPB BTG1 TFRC MT2A RGS10 HMGN2 AIF1 HLA-DRA MZB1 CFLAR TAP1 SLC38A1 SULF2
TYMP ATP6V1F TIMP1 LGALS3 GCHFR CIB1 PCNA VASH1 HLA-DPB1 IRF8 GPR157 HLA-DPA1 DENND1B LGMN
NAMPT FTH1 METRNL CTSD SMIM25 LST1 DEK JAML AXL SMPD3 MARCKS RAB7B LIMD2 HINT1
TAP1 EMP3 DUSP1 NUPR1 SLC11A1 HLA-DPA1 TK1 CD93 MS4A6A IL3RA NR4A3 ENPP1 KCNK6 BID
BCL2A1 ANXA2 OLR1 SERPING1 FAM89A CLTB GAPDH NAIP FGL2 C12orf75 IL4I1 CPVL HLA-DPB1 ATF5
PSME2 ANXA5 ATP13A3 TSPO TGM2 SMIM26 MKI67 GIMAP7 SLC40A1 CLIC3 BTG1 BATF3 NDRG2 PLAC8
APOL6 GSTO1 SGK1 SCD PLIN2 ECHS1 TOP2A SIGLEC10 CTTNBP2 PLAC8 GPR183 HLA-DPB1 HLA-DQB1 RPS23
CXCL11 SGK1 CD44 MGST3 TFRC RPL26 UBE2C CEBPA CLEC4F MALAT1 ETV3 XCR1 LRRK1 TUBA1A
APOBEC3A SLC11A1 SOD2 CSTB ALDH1A1 CLNS1A ASPM COTL1 LPAR6 SEC61B CSF2RA IDO1 GDI2 LRRFIP1
ANKRD22 ASAH1 FOSL2 AC026369.3 ABCG1 SNHG8 CENPF S100A12 GPR155 APP EIF1 HLA-DQA1 RUNX3 CSF2RB
RSAD2 PDXK HIF1A CES1 PPARG GDI2 CENPM CDA ADAM28 SELL CMTM6 CD74 ENHO SPIB
GCH1 PLA2G7 H3F3B LTA4H MSR1 ALKBH7 TPX2 FCN1 C3 SERPINF1 CDKN1A BASP1 CD207 RAB11FIP1
ETV7 FTL LITAF GLRX SERPING1 COMMD6 FAM111B GIMAP1 AL138899.1 LILRA4 PNRC1 CCND1 GIPR MRPS6
TNFSF10 AQP9 GK CYP27A1 GLUL RPL31 RRM2 ASGR1 USP53 MAP1A EEF1A1 CLNK PIK3R6 CRYBG1
IFIT3 LAMP1 EREG C1QC CYB5A RPL34 GTSE1 CCR2 GPR34 DERL3 CCL19 HLA-DQB1 SUSD3 PLP2
IFIT2 GLUL THBS1 GPNMB HDDC2 FAM162A CDKN3 CXCR2 CCL4L2 ALOX5AP LAD1 C1orf54 RAB33A SYNGR2
FPR2 S100A11 OSM CD52 MS4A7 ATP5MC2 CEP55 ASGR2 SRGAP1 SCT FSCN1 RAB11FIP1 SYTL1 S100A10
HAPLN3 CD9 CD300E FABP5 S100A9 RPL38 CDK1 GIMAP6 NPFFR1 SPCS1 EBI3 ASAP1 LINC02381 NFAT5
ISG15 CD44 VCAN PPARG BLVRB NDUFS8 ZWINT GIMAP8 CCL3L1 HERPUD1 SLCO5A1 SLAMF7 KCNMB1 ACTG1
VAMP5 SH3BGRL3 HBEGF TXNIP NCF4 RTRAF BIRC5 CD1D GUCY1A1 SPIB ANKRD33B ICAM3 FILIP1L CCND3
IL1B MGAT1 SDS LGALS3BP CHCHD10 RPL32 AURKB CEACAM4 ARHGEF12 CYB561A3 MIR155HG FNBP1 CACNA2D3 SUSD1
HERC5 VIM IL1B FAM89A GYPC BTF3 CCNB2 DUSP6 CX3CR1 LINC00996 HMSD ARF6 ABCB4 ALOX5AP
RARRES3 PLIN2 SLC16A10 S100A11 SCD RPL36 ESCO2 RCBTB2 MERTK LTB DUSP4 CST7 PPP1R16A SEC61B
OAS2 FBP1 MMP19 CFD HSPA1A SEC11A KIF11 PTGFRN PCNX2 PLD4 TNFRSF9 RAB32 SCARF1 UGCG
APOL3 DAB2 CXCL8 ALDH2 VSIG4 EIF3F NUSAP1 GIMAP4 UNC5B CLEC4C IL32 HLA-DRB1 ITGB2-AS1 INPP4A
FCGR1B CD68 ZNF331 FBP1 CSTA RPL12 NDC80 CFP FCER1A STMN1 TRAF1 ACTG1 AL118516.1 RASSF5
FCGR1A CD163 MXD1 S100A6 FBP1 UQCR10 CCNA2 CSF3R MAF CD2AP GADD45A DENND1B HOMER2 CD5
CLEC4E C15orf48 S100A12 VSIG4 CEBPA RPL39 CCNB1 MAP3K3 STAB1 SLC15A4 TBC1D4 RGCC PCNX2 UPK3A
IFI44 S100A10 SLC25A37 ACP5 STXBP2 RPL29 CDCA3 NFAM1 MMP14 PMEPA1 KIF2A BCL6 NFATC2 GPR153
IFI44L NCF2 CXCL2 FCGR3A CD59 SLC25A5 UHRF1 AC096667.1 DNASE1L3 CCDC186 SOCS2 HLA-DOB RAB11FIP4 IGLON5
IFITM3 CCL2 AREG CD81 LRP1 ATP5F1C BUB1B AGTRAP GIMAP7 FCHSD2 CD80 HLA-DRB5 PAOX ABCA6
XAF1 SPP1 FCN1 PCOLCE2 SLCO2B1 RAP1A CLSPN SH3PXD2B LTC4S ZFAT RAMP1 RGS10 FAM89B GADD45A
PSTPIP2 RNASE1 INSIG1 ALDH1A1 TNFSF12 RPS24 CENPU KCNE3 GATM MPEG1 PTGIR DNASE1L3 GOLGA8N MYOM1
P2RX7 EMP1 NR4A1 TYROBP AKR1B1 HIGD2A NUF2 IRAK3 SRGAP3 PLP2 POGLUT1 DST AC009093.2 SOX4
IFIH1 PLPP3 VPS37B TREM1 GLRX RPL35 CENPK CATSPER1 EGR1 OFD1 SLC7A11 ZNF366 LPAR5 AC124319.2
CD274 OTOA C5AR1 PLIN2 ALOX5 SUMO3 ATAD2 TLR8 QPRT NUCB2 MARCKSL1 LGALS2 PAK1 ACSM3
FPR1 SDC2 PPIF CTSL PPDPF RPS15 MAD2L1 TPCN1 SLCO2B1 SELENOS DUSP5 CCSER1 TMEM273 GPR171
PIM1 NPL SEMA6B PLA2G16 NCF2 TPT1 CENPA FHL3 HLA-DQB2 RPS12 CCL17 HLA-DRA GSE1 DAPK2
DDX60 GSDME PTGS2 IFI6 HNMT UQCRB PTTG1 CAMK1 FILIP1L SEL1L3 G0S2 CAMK2D PON2 ECE1

Now with the AUROC and OVA marker genes, we can visualize the markers with a heatmap. First, we looked at the data with a heatmap where both the rows and columns were clustered.

from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
# Collect the union of every cluster's top genes (first-seen order) and note
# which genes appear for more than one cluster.
heatmap_genes = []
repeated_genes = [] # Get genes that are not unique, do not want to annotate them
seen_genes = set()  # constant-time membership test instead of scanning the list
for key in top_genes.keys() :
    for gene in top_genes[key] :
        if gene not in seen_genes :
            seen_genes.add(gene)
            heatmap_genes.append(gene)
        else :
            repeated_genes.append(gene)

# Get the genes for annotation: top markers that are not in repeated genes
# (names starting with "RP" are skipped as well)
repeated_lookup = set(repeated_genes)
annot_genes = {}
for clust in top_genes.keys() :
    non_rep_genes = [gene for gene in top_genes[clust]
                     if gene not in repeated_lookup and not gene.startswith("RP")]
    annot_genes[clust] = non_rep_genes

# Write out the annotation genes for the heatmap (making with ComplexHeatmap)
annot_genes = pd.DataFrame({k: pd.Series(v) for k, v in annot_genes.items()})
annot_genes = annot_genes.rename(columns = {clust : "cluster_{clust}".format(clust=clust) for clust in annot_genes.columns})
# Lets add the colors for each cluster from the UMAP
clust_cols = dict(zip(sorted(set(mnp_harmonized.obs["new_clusters"]), key = int),
                      mnp_harmonized.uns["new_clusters_colors"]))
clust_cols = pd.DataFrame(clust_cols, index = ["col"]).rename(
    columns = {clust : "cluster_{clust}".format(clust = clust) for clust in clust_cols.keys()})

# DataFrame.append was removed in pandas 2.0; pd.concat stacks the rows the
# same way and works on every pandas version.
annot_genes = pd.concat([annot_genes, clust_cols])

# Also need to add mean gene counts
# Get the mean gene counts for sidebar
gene_val_list = []
gene_val_dict = {}
for clust in sorted(set(mnp_harmonized.obs["new_clusters"]), key = int) :
    gene_vals = mnp_harmonized.obs["n_genes"][mnp_harmonized.obs["new_clusters"] == clust]

    mean = np.mean(gene_vals)
    gene_val_list.append(mean)
    gene_val_dict[clust] = mean

# Append these mean gene counts to the dataframe as a "mean_genes" row
mean_genes_row = pd.DataFrame(gene_val_dict, index = ["mean_genes"]).rename(
    columns = {clust : "cluster_{clust}".format(clust = clust) for clust in gene_val_dict.keys()})
annot_genes = pd.concat([annot_genes, mean_genes_row])


# Get the mean expression of the top genes from each cluster
de_df = {"mean_log_{clust}".format(clust = clust) : mnp_harmonized.varm["de_res"]["mean_logExpr:{clust}".format(clust = clust)] for clust in sorted(set(mnp_harmonized.obs["new_clusters"]), key = int)}
de_df = pd.DataFrame(de_df, index = mnp_harmonized.var.index)

# Rows = selected genes, columns = per-cluster mean log expression
heatmap_df = de_df.loc[heatmap_genes]


# One palette color per cluster, used to encode the rank of the mean n_genes.
colors = sns.color_palette("ch:2.5,-.2,dark=.2", n_colors = len(gene_val_list)).as_hex()
# Put the gene values in order lowest to highest
sorted_cols = sorted(gene_val_list)

# Main heatmap axis plus side axes: heatmap colorbar (axDivY), a hidden
# helper axis (axDivY2), the n_genes colorbar (axDivY3) and a top strip (axDivY4).
fig, ax = plt.subplots(1, 1, figsize = (10, 10))
divider = make_axes_locatable(ax)
axDivY = divider.append_axes( 'right', size=0.2, pad= 0.1)
axDivY2 = divider.append_axes( 'right', size=0.2, pad= 0.2)
axDivY3 = divider.append_axes( 'right', size=0.2, pad= 0.2)
axDivY4 = divider.append_axes( 'top', size=0.2, pad= 0.2)

# clustermap is only used to obtain the z-scored, reordered matrix (data2d).
ax1 = sns.clustermap(heatmap_df, method = "ward", row_cluster =True, col_cluster =True, z_score = 0, cmap = "vlag")
# Cluster id = text after the last "_" of each reordered column name.
col_order = np.array([name.split("_")[-1] for name in ax1.data2d.columns])
# Rank of each cluster's mean n_genes, in the clustered column order.
index = [sorted_cols.index(gene_val_dict[clust]) for clust in col_order]
# NOTE(review): presumably closes the figure that clustermap created, so only
# `fig` is shown - confirm.
plt.close()
ax1 = sns.heatmap(ax1.data2d, cmap = "vlag", ax = ax, cbar_ax = axDivY)
# Invisible image spanning min..max mean n_genes; it exists only to give the
# n_genes colorbar a mappable.
ax2 = axDivY2.imshow(np.array([[min(gene_val_list), max(gene_val_list)]]), cmap = mpl.colors.ListedColormap(list(colors)),
                     interpolation = "nearest", aspect = "auto")
axDivY2.set_axis_off()
axDivY2.set_visible(False)
_ = plt.colorbar(ax2, cax = axDivY3)
_ = axDivY3.set_title("n_genes")
# Top strip: one cell per column, colored by the n_genes rank computed above.
ax3 = axDivY4.imshow(np.array([index]),cmap=mpl.colors.ListedColormap(list(colors)),
              interpolation="nearest", aspect="auto")
axDivY4.set_axis_off()
_ = plt.title("top genes for every cluster")
plt.show()

To make things more readable, we also made a heatmap where we kept the columns clustered such that phenotypically similar clusters were grouped together, but manually ordered the rows.

# Rebuild the gene list following the clustered column order, so each
# cluster's genes form one contiguous run of rows, and count how many new
# rows every cluster contributes.
n_heatmap_genes = {}
heatmap_genes = []
seen_genes = set()  # constant-time membership test instead of scanning the list
for key in col_order :
    cnt = 0
    for gene in top_genes[key] :
        if gene not in seen_genes :
            seen_genes.add(gene)
            heatmap_genes.append(gene)
            cnt+=1
    n_heatmap_genes[key] = cnt

n_heatmap_genes = pd.DataFrame(n_heatmap_genes, index = ["n_genes"]).rename(
    columns = {clust : "cluster_{clust}".format(clust = clust) for clust in n_heatmap_genes.keys()})
# Add number of genes in the heatmap for each clusters
# (DataFrame.append was removed in pandas 2.0; pd.concat stacks the rows the
# same way and works on every pandas version)
annot_genes = pd.concat([annot_genes, n_heatmap_genes])
# Turn the row labels into an "index" column and blank out missing cells
annot_genes = annot_genes.reset_index()
annot_genes = annot_genes.fillna('')


# Get the mean expression of the top genes from each cluster
de_df = {"mean_log_{clust}".format(clust = clust) : mnp_harmonized.varm["de_res"]["mean_logExpr:{clust}".format(clust = clust)] for clust in sorted(set(mnp_harmonized.obs["new_clusters"]), key = int)}
de_df = pd.DataFrame(de_df, index = mnp_harmonized.var.index)

heatmap_df = de_df.loc[heatmap_genes]

# Get the mean gene counts for sidebar
gene_val_list = []
gene_val_dict = {}
for clust in sorted(set(mnp_harmonized.obs["new_clusters"]), key = int) :
    gene_vals = mnp_harmonized.obs["n_genes"][mnp_harmonized.obs["new_clusters"] == clust]

    mean = np.mean(gene_vals)
    gene_val_list.append(mean)
    gene_val_dict[clust] = mean

# One palette color per cluster, used to encode the rank of the mean n_genes
colors = sns.color_palette("ch:2.5,-.2,dark=.2", n_colors = len(gene_val_list)).as_hex()
# Put the gene values in order lowest to highest
sorted_cols = sorted(gene_val_list)

# Main heatmap axis plus side axes: heatmap colorbar (axDivY), a hidden
# helper axis (axDivY2), the n_genes colorbar (axDivY3) and a top strip (axDivY4).
fig, ax = plt.subplots(1, 1, figsize = (10, 10))
divider = make_axes_locatable(ax)
axDivY = divider.append_axes( 'right', size=0.2, pad= 0.1)
axDivY2 = divider.append_axes( 'right', size=0.2, pad= 0.2)
axDivY3 = divider.append_axes( 'right', size=0.2, pad= 0.2)
axDivY4 = divider.append_axes( 'top', size=0.2, pad= 0.2)

# color_label_list =[random.randint(0,14) for i in range(14)]
# Rows keep the manual gene order (row_cluster=False); only columns are clustered.
ax1 = sns.clustermap(heatmap_df, method = "ward", row_cluster =False, col_cluster =True, z_score = 0, cmap = "vlag")
# Cluster id = text after the last "_" of each reordered column name.
col_order = np.array([name.split("_")[-1] for name in ax1.data2d.columns])
# Rank of each cluster's mean n_genes, in the clustered column order.
index = [sorted_cols.index(gene_val_dict[clust]) for clust in col_order]
# NOTE(review): presumably closes the figure that clustermap created, so only
# `fig` is shown - confirm.
plt.close()

# Keep the z-scored, column-reordered matrix; the R chunk reads it through
# reticulate as `heatmap_carpet`.
heatmap_carpet = ax1.data2d

ax1 = sns.heatmap(ax1.data2d, cmap = "vlag", ax = ax, cbar_ax = axDivY)
# Invisible image spanning min..max mean n_genes; it exists only to give the
# n_genes colorbar a mappable.
ax2 = axDivY2.imshow(np.array([[min(gene_val_list), max(gene_val_list)]]), cmap = mpl.colors.ListedColormap(list(colors)),
                     interpolation = "nearest", aspect = "auto")
axDivY2.set_axis_off()
axDivY2.set_visible(False)
_ = plt.colorbar(ax2, cax = axDivY3)
_ = axDivY3.set_title("n_genes")
# Top strip: one cell per column, colored by the n_genes rank computed above.
ax3 = axDivY4.imshow(np.array([index]),cmap=mpl.colors.ListedColormap(list(colors)),
              interpolation="nearest", aspect="auto")
axDivY4.set_axis_off()
_ = plt.title("top genes for every cluster")
plt.show()

Finally, we wanted to make a publication-ready figure using the wonderful ComplexHeatmap package, where we can add some annotations for each cluster and add spaces between clusters to make it even more readable.

library(ComplexHeatmap)
library(tidyverse)
library(magrittr)
library(circlize)

# Pull the heatmap matrix and the per-cluster annotation table from Python.
heatmap_data <- reticulate::py$heatmap_carpet
annotation_info <- reticulate::py$annot_genes
rownames(annotation_info) <- annotation_info$index
annotation_info$index <- NULL

# Flatten every column into a plain vector (columns may arrive as lists).
for (col_name in colnames(annotation_info)) {
  annotation_info[[col_name]] <- unlist(annotation_info[[col_name]])
}

# Change the column names to be cleaner: keep the text after the last "_"
# ("mean_log_3" -> "Cluster 3"). This does not depend on the prefix having a
# fixed number of "_"-separated tokens.
colnames(heatmap_data) <- paste("Cluster", sub("^.*_", "", colnames(heatmap_data)))

# Make column names consistent with heatmap data ("cluster_3" -> "Cluster 3")
colnames(annotation_info) <- str_to_title(
  str_replace(colnames(annotation_info), "_", " ")
)

# Lets just get the genes: the first five annotation rows of every cluster
annotation_genes <- unique(as.character(unlist(annotation_info[1:5, ])))
annotation_genes <- annotation_genes[annotation_genes != ""]

# Now lets organize the color info that will be used for annotations:
# long table with one row per (cluster, annotation value) plus the cluster color
col_info <- annotation_info %>%
  t() %>%
  as.data.frame() %>%
  dplyr::select(-mean_genes) %>%
  rownames_to_column(var = "cluster") %>%
  reshape2::melt(id.vars = c("cluster", "col")) %>%
  dplyr::select(-variable)

# Get the gene colors: the color of the cluster each gene is listed under.
# match() returns exactly one entry per gene, so the colors always stay
# aligned with `annotation_genes`.
gene_cols <- as.character(
  col_info$col[match(annotation_genes, col_info$value)]
)

# Get the cluster colors, in heatmap column order
clust_cols <- as.character(
  col_info$col[match(colnames(heatmap_data), col_info$cluster)]
)

# Mean number of genes per cluster, ordered like the heatmap columns.
# Base R conversion replaces the deprecated mutate_each(funs(...)).
mean_genes <- annotation_info["mean_genes", colnames(heatmap_data), drop = FALSE]
mean_genes[] <- lapply(mean_genes, function(x) as.numeric(as.character(x)))

gene_col_fun <- colorRamp2(
  c(min(mean_genes), max(mean_genes)),
  c("#1d111d", "#bbe7c8")
)
gene_bar <- HeatmapAnnotation(
  "mean # genes" = as.numeric(mean_genes),
  col = list("mean # genes" = gene_col_fun),
  show_legend = FALSE
)
gene_lgd <- Legend(
  col_fun = gene_col_fun, title = "# genes",
  legend_height = unit(4, "cm"), title_position = "topcenter"
)


heatmap_col_fun <- colorRamp2(
  c(min(heatmap_data), 0, max(heatmap_data)),
  c("purple", "black", "yellow")
)
heatmap_lgd <- Legend(
  col_fun = heatmap_col_fun, title = "z-score",
  legend_height = unit(4, "cm"), title_position = "topcenter"
)

lgd_list <- packLegend(
  heatmap_lgd, gene_lgd,
  column_gap = unit(1, "cm"), direction = "horizontal"
)

# Row split: repeat each cluster id once per heatmap row that the cluster
# contributed (the "n_genes" annotation row), in heatmap column order.
n_per_cluster <- vapply(
  colnames(heatmap_data),
  function(clust) as.numeric(as.character(annotation_info["n_genes", clust])),
  numeric(1)
)
split <- rep(gsub("Cluster ", "", colnames(heatmap_data)), times = n_per_cluster)
split <- factor(split, levels = unique(split))

# Make block annotation
# One colored block per row slice, filled with the cluster colors.
left_annotation =   HeatmapAnnotation(blk = anno_block(gp = gpar(fill = clust_cols, col = clust_cols)), which = "row", width = unit(1.5, "mm"))

# Main heatmap: rows keep their given order and are split by `split`; columns
# are clustered (ward.D2 on euclidean distance). A row annotation is attached
# that labels only the genes in `annotation_genes`, colored with `gene_cols`.
heatmap_list = Heatmap(heatmap_data, name = "z-score", col = heatmap_col_fun, cluster_rows = FALSE, cluster_columns = TRUE,
                       cluster_row_slices = FALSE, row_km = 1, cluster_column_slices = FALSE,
                       clustering_method_columns = "ward.D2", clustering_distance_columns = "euclidean",
                       column_dend_reorder = FALSE, top_annotation = gene_bar, show_heatmap_legend = FALSE,
                       column_names_gp = gpar(col = clust_cols, fontface = "bold"),
                       split = split, left_annotation = left_annotation, show_column_names = FALSE) +
  rowAnnotation(link = anno_mark(at = match(annotation_genes, rownames(heatmap_data)),labels = annotation_genes,
                                 labels_gp = gpar(col = gene_cols, fontsize = 8, fontface = "bold")))

# Draw with the packed legends (built-in heatmap legend is switched off above).
draw(heatmap_list, heatmap_legend_list =lgd_list, padding = unit(c(0.5, 0.5, 2, 2), "cm"), cluster_rows = FALSE,
     cluster_row_slices = FALSE)

Beyond looking at unbiased markers, we wanted to use legacy knowledge to better understand the biological role of our clusters. Using dot plots, we can see some canonical genes expressed by the clusters.

# Canonical marker genes, grouped by the MNP population they are used to flag.
MNP_gene_dict = {
    "monocyte": ["CD14", "FCGR3A", "S100A8", "C5AR1", "FCN1", "VCAN", "CCR1", "CCR2"],
    "cDC": [
        "IRF4", "CD1C", "CD1E", "FCER1A", "CLEC10A", "IRF8", "BATF3", "CLEC9A", "XCR1", "CCR7", "LAMP3",
        "AXL", "PPP1R14A", "TCF4", "GZMB", "JCHAIN", "IL3RA",
    ],
    "macrophage": ["MARCO", "MSR1", "VSIG4", "C1QA", "C1QB", "C1QC", "PPARG", "CD163"],
    "macrophage_specialization": [
        "FABP4", "APOE", "VEGFA", "AREG", "SPP1", "MERTK", "CCL2", "CCL3", "CCL4",
        "CXCL9", "CXCL10", "CXCL11",
    ],
}

# Four-stop colour ramp (light grey to dark blue) used by every dot plot.
dot_cmap = clr.LinearSegmentedColormap.from_list(
    'gene_cmap', ["#d3d3d3", "#a5cde3", "#6dafd6", '#08306b'], N=200
)

# Cluster order shared by all dot plots.
cluster_order = ["11", "5", "6", "7", "13", "2", "14", "1", "8", "4", "9", "10", "3", "12"]

# One dot plot per marker group, with a grid drawn behind the dots.
for marker_group, marker_genes in MNP_gene_dict.items():
    dot_fig = sc.pl.dotplot(
        mnp_harmonized,
        marker_genes,
        groupby="new_clusters",
        categories_order=list(cluster_order),  # fresh list per call, as before
        show=False,
        return_fig=True,
        title="{key} markers".format(key=marker_group),
        cmap=dot_cmap,
        dot_max=1,
    )
    main_ax = dot_fig.get_axes()["mainplot_ax"]
    main_ax.set_axisbelow(True)
    main_ax.grid()
    plt.show()
    plt.close()

After inspecting all of the top marker genes for each cluster, we assigned the clusters specific annotations, which are shown below.


# Final annotation for each cluster label found in obs["new_clusters"].
annotation_dict = {
    "1": "MC1 (CXCL10)",
    "2": "MC2 (SPP1)",
    "3": "MC3 (AREG)",
    "4": "Mac1 (FABP4)",
    "5": "quiesMac",
    "6": "quiesMC",
    "7": "Cycling (PCLAF)",
    "8": "MC4 (CCR2)",
    "9": "Mac2 (A2M)",
    "10": "pDC (TCF4)",
    "11": "migDC (CCR7)",
    "12": "DC1 (CLEC9A)",
    "13": "DC2 (CD1C)",
    "14": "AS DC (AXL)",
}

# Give every annotation the colour of its cluster. The colours are paired with
# the categories of obs["new_clusters"] instead of with the dict's insertion
# order, so the mapping stays correct even if the categories are not ordered
# "1", "2", ..., "14" (e.g. sorted as strings).
# NOTE(review): relies on scanpy keeping uns["new_clusters_colors"] in the order
# of the column's categories - confirm if the colours were set by hand.
cluster_categories = mnp_harmonized.obs["new_clusters"].astype("category").cat.categories
color_dict = {
    annotation_dict[cluster]: color
    for cluster, color in zip(cluster_categories, mnp_harmonized.uns["new_clusters_colors"])
}

# Per-cell annotation label, then the UMAP coloured by it.
mnp_harmonized.obs["annotation"] = [annotation_dict[c] for c in mnp_harmonized.obs["new_clusters"]]
figure = sc.pl.umap(mnp_harmonized, color = ["annotation"],
                    palette = color_dict, return_fig = True, show = False, legend_loc = "on data")
figure.set_size_inches(8, 8)
figure

```

Differential abundance analysis

Next we wanted to evaluate which clusters were associated with particular group:condition combinations. First, we can do this in a qualitative way, looking at the UMAP embedding density of the cells that are in our different categorical groups. The takeaway from these plots is that the cells are not uniformly distributed across the categorical groups and there are some biases.

# Label every cell "<phenotype>_<sample>"; this is the grouping used for the
# density estimates below.
obs = mnp_harmonized.obs
mnp_harmonized.obs["pheno_tmpt"] = [
    "_".join(pair) for pair in zip(obs["phenotype"], obs["sample"])
]

# UMAP embedding density per group, stored under "pheno_tmpt_dens".
sc.tl.embedding_density(
    mnp_harmonized,
    basis="umap",
    groupby="pheno_tmpt",
    key_added="pheno_tmpt_dens",
)

# One panel per group, three panels per row.
dens_plot = sc.pl.embedding_density(
    mnp_harmonized,
    basis="umap",
    key="pheno_tmpt_dens",
    ncols=3,
    return_fig=True,
    show=False,
    wspace=0.3,
)
dens_plot.set_size_inches(7, 5)
dens_plot

To look at the data in a more quantitative way, we used a mixed-effects association logistic regression model similar to that described by Fonseka et al. We fit a logistic regression model for each cell cluster. Each cluster was modelled independently as follows: `cluster ~ 1 + condition:group + condition + group + (1 | id)`

The least-squares means of the factors in the model were calculated and all pairwise contrasts between the means of the groups at each condition (e.g. AA vs AC within baseline, AA vs AC within allergen, etc.) were compared. The odds ratio (OR) with its confidence interval for each sample/condition combination was plotted.

# Keep the obs table under its own name; the R code below reads it through
# reticulate as py$cell_info.
cell_info = mnp_harmonized.obs

library(lme4)
library(ggplot2)
library(emmeans)

# Fit one mixed-effects logistic regression per cluster and collect, for each
# cluster, the pairwise phenotype contrasts within every sample.
#
# Arguments:
#   dataset         data frame with one row per cell. Must contain the columns
#                   named in `contrast`, `fixed_effects` and `random_effects`,
#                   plus `phenotype` and `sample`, which the contrasts below
#                   hard-code.
#   cluster         cluster assignment of each row of `dataset`.
#   contrast        character; term(s) of interest, e.g. "sample:phenotype".
#   random_effects  character vector of grouping variables; each one becomes a
#                   random intercept `(1|x)`. At least one is required.
#   fixed_effects   character vector of additional fixed-effect terms.
#
# Returns a list named after the indicator columns ("cluster<level>"). Each
# element is the summary of the pairwise contrasts (with a `cluster` column)
# joined to the asymp.LCL / asymp.UCL bounds obtained from eff_size().
get_abund_info <- function(dataset, cluster, contrast, random_effects = NULL, fixed_effects = NULL){
  # glmer() cannot fit a model without a random-effect term; fail early with a
  # clear message instead of building the malformed term "(1|)".
  if (length(random_effects) == 0) {
    stop("`random_effects` must name at least one grouping variable",
         call. = FALSE)
  }

  # One 0/1 indicator column per cluster level, bound to the cell metadata
  cluster <- as.character(cluster)
  designmat <- model.matrix(~ cluster + 0, data.frame(cluster = cluster))
  dataset <- cbind(designmat, dataset)
  cluster_cols <- colnames(designmat)

  # Right-hand side shared by all models: intercept, term(s) of interest,
  # remaining fixed effects, then the random intercepts. Building it from a
  # vector of terms means a NULL `fixed_effects` leaves no dangling " + ".
  model_rhs <- paste(c("1", contrast, fixed_effects,
                       paste0("(1|", random_effects, ")")),
                     collapse = " + ")

  # Result list: one element per indicator column
  cluster_models <- vector(mode = "list", length = length(cluster_cols))
  names(cluster_models) <- cluster_cols

  for (i in seq_along(cluster_cols)) {
    test_cluster <- cluster_cols[i]

    # Response is the indicator column of this cluster; backticks keep the
    # formula valid when a cluster label is not a syntactic R name.
    full_fm <- as.formula(paste0("`", test_cluster, "` ~ ", model_rhs))

    full_model <- lme4::glmer(formula = full_fm, data = dataset,
                              family = binomial, nAGQ = 1, verbose = 0,
                              control = glmerControl(optimizer = "bobyqa"))

    # Pairwise phenotype comparisons within each sample
    pvals <-lsmeans(full_model, pairwise ~ "phenotype | sample")

    p_val_df <- summary(pvals$contrasts)
    p_val_df$cluster <- test_cluster

    # NOTE(review): the interval bounds come from eff_size() rather than from
    # confint() on the contrasts - confirm this is the intended interval.
    ci <- eff_size(pvals, sigma = sigma(full_model), edf = df.residual(full_model))
    ci_df <- dplyr::select(summary(ci), sample, asymp.LCL, asymp.UCL)
    ci_df$cluster <- test_cluster

    cluster_models[[i]] <- dplyr::left_join(p_val_df, ci_df,
                                            by = c("sample", "cluster"))
  }
  return(cluster_models)
}

# Forest plot for the differential abundance analyses: one facet per sample,
# one row per cluster, point = estimate, horizontal bar = asymp.LCL..asymp.UCL.
#
# Arguments:
#   d            data frame with the columns `cluster`, `estimate`,
#                `asymp.LCL`, `asymp.UCL`, `sig` ("AA", "ANA" or "non_sig") and
#                `sample`. The x axis places its breaks at log(odds ratio) and
#                labels them with exp(), so the estimates are expected on the
#                log scale.
#   x_breaks_by  currently unused; kept so existing calls keep working.
#   wrap_ncol    number of facet columns.
#   y_ord        currently unused; kept so existing calls keep working. Order
#                the rows by making `d$cluster` a factor before calling.
#
# Returns the ggplot object.
plot_contrasts <- function(d, x_breaks_by = 1, wrap_ncol = 6, y_ord = FALSE) {
  n_clusters <- length(unique(d$cluster))
  # Centres of the shaded background stripes (every other row)
  stripe_centres <- seq(from = 1, to = n_clusters, by = 2)
  # Labels come from the data passed in, not from a global `plot_df`;
  # factor() keeps an existing level order and drops unused levels, so the
  # labels always match the rows that are drawn.
  cluster_labels <- levels(factor(d$cluster))

  ggplot() +
    annotate(
      geom = "rect",
      xmin = -Inf,
      xmax = Inf,
      ymin = stripe_centres - 0.5,
      ymax = stripe_centres + 0.5,
      alpha = 0.2
    ) +
    geom_vline(xintercept = 0, size = 0.2) +
    geom_errorbarh(
      data = d,
      mapping = aes(
        xmin = asymp.LCL, xmax = asymp.UCL, y = cluster,
        color = sig
      ),
      height = 0
    ) +
    geom_point(
      data = d,
      mapping = aes(
        x = estimate, y = cluster,
        color = sig
      ),
      size = 3
    ) +
    # Named values pin each colour to its category; unnamed values are matched
    # by position, which depends on which categories are present and on how
    # the installed ggplot2 pairs them with `breaks`.
    scale_color_manual(
      name = "P < 0.05",
      values = c(AA = "#FF8000", ANA = "#40007F", non_sig = "grey60"),
      breaks = c("AA", "ANA")
    ) +
    scale_x_continuous(
      breaks = log(c(0.125, 0.5, 1, 2, 4)),
      labels = function(x) exp(x)
    ) +
    scale_y_discrete(
      labels = cluster_labels
    ) +
    # Leave half a row of space below the first and above the last cluster
    expand_limits(y = c(0.5, n_clusters + 0.5)) +
    facet_wrap(~ sample, ncol = wrap_ncol) +
    theme_classic() +
    theme(
      strip.text = element_text(face = "italic"),
      text = element_text(size = 15)
    ) +
    labs(
      title = "contrasts by sample",
      x = "Odds Ratio",
      y = "cluster"
    )
}

# Pull the per-cell table over from Python.
clust_df <- reticulate::py$cell_info

# One model per cluster: sample-by-phenotype interaction plus main effects,
# with a random intercept per `id`.
abund_info <- get_abund_info(
  clust_df,
  cluster = clust_df$new_clusters,
  contrast = "sample:phenotype",
  random_effects = "id",
  fixed_effects = c("sample", "phenotype")
)

# Row order of the forest plot (reversed below so the first entry ends up on top).
cl_order <- c(11, 5, 6, 7, 13, 2, 14, 1, 8, 4, 9, 10, 3, 12)

# Stack the per-cluster tables and derive the plotting columns.
plot_df <- do.call(rbind, abund_info)
# Sign of the estimate decides the direction label
plot_df$direction <- ifelse(plot_df$estimate > 0, "AA", "ANA")
# Only contrasts with p < 0.05 keep their direction as colour category
plot_df$sig <- ifelse(plot_df$p.value < 0.05, plot_df$direction, "non_sig")
# Strip the "cluster" prefix of the indicator names and order the clusters
plot_df$cluster <- factor(
  as.numeric(gsub("cluster", "", plot_df$cluster)),
  levels = rev(cl_order)
)
plot_df$sample <- factor(plot_df$sample, levels = c("Pre", "Dil", "Ag"))

plot_contrasts(plot_df)